Update cuda_specs.py #1833
asmertpc-cloud wants to merge 5 commits into bitsandbytes-foundation:main
fix for windows
TimDettmers left a comment
PR Review: #1833 — Update cuda_specs.py
Platform fix: attempts to make ROCm GPU detection work on Windows by switching rocminfo to hipinfo and adjusting warp-size regex parsing.
Blocking issues (2):
- get_rocm_gpu_arch() will not work on Windows: This PR changes the subprocess command from rocminfo to hipinfo on Windows, but keeps the same regex pattern Name:\s+gfx(...). On Windows, hipinfo.exe reports the GPU architecture under gcnArchName:, not Name:. The regex will never match hipinfo output, so get_rocm_gpu_arch() will always return "unknown" on Windows. PR #1846 handles this correctly by using distinct regex patterns per platform (gcnArchName:\s+(gfx...) for Windows vs Name:\s+gfx(...) for Linux).
- Superseded by PR #1846: PR #1846 (by the same author) is a more comprehensive fix that addresses the same cuda_specs.py changes plus CMakeLists.txt build support, csrc/ops.cuh and csrc/ops_hip.cuh Windows portability fixes, and NOMINMAX defines. PR #1843 (by a different author) also addresses the same problem. This PR (#1833) is the least complete of the three and has correctness issues that #1846 does not. I recommend closing this PR in favor of #1846.
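The per-platform approach described for PR #1846 can be sketched without any GPU. The sample strings below are hand-written from the field names quoted in this review, so their spacing is an assumption, and `parse_arch` and `ARCH_PATTERNS` are hypothetical names, not code from bitsandbytes or from #1846. One detail worth checking when comparing patterns: `re.search` is unanchored, so a bare `Name:` pattern can also match the tail of a `gcnArchName:` label. The sketch anchors each pattern to the start of a line to keep the two fields distinct.

```python
import re

# Hand-written samples; real tool output has many more fields and the
# spacing shown here is assumed for illustration.
ROCMINFO_SAMPLE = "  Name:                    gfx1100\n  Marketing Name:          Radeon\n"
HIPINFO_SAMPLE = "Name:                             AMD Radeon\ngcnArchName:                      gfx1100\n"

# Hypothetical per-platform patterns, each anchored to the start of a line.
ARCH_PATTERNS = {
    "Windows": re.compile(r"^\s*gcnArchName:\s+(gfx[a-z\d]+)", re.IGNORECASE | re.MULTILINE),
    "Linux": re.compile(r"^\s*Name:\s+(gfx[a-z\d]+)", re.IGNORECASE | re.MULTILINE),
}


def parse_arch(output: str, system: str) -> str:
    """Return the gfx architecture string, or "unknown" if none is found."""
    pattern = ARCH_PATTERNS["Windows" if system == "Windows" else "Linux"]
    match = pattern.search(output)
    return match.group(1) if match else "unknown"


print(parse_arch(HIPINFO_SAMPLE, "Windows"))
print(parse_arch(ROCMINFO_SAMPLE, "Linux"))
```

Keeping the patterns in one table makes it easy to see which field each platform relies on, and the anchoring means the Linux pattern cannot be satisfied by a `gcnArchName:` line.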
Additional issues:
- Import ordering: import sys is placed after import torch, violating PEP 8 (stdlib imports before third-party). The linter would catch this.
- Module-level side effect: The if sys.platform == "win32" block runs at module import time for all users, including non-ROCm users on Linux. While harmless (it just sets a string variable), it's cleaner to keep the platform check inside the functions that need it, as #1846 does.
- Warp-size regex fragility: The combined regex (wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))? is fragile. The [0-9]{2} requires exactly two digits, which works for 32 and 64 but mis-parses any value with a different number of digits. PR #1846's approach of using separate, explicit patterns per platform is more maintainable.
- No tests: Bug fix PRs should include regression tests. However, testing subprocess-based hardware detection is difficult, so this is non-blocking.
- CI has not run: No checks reported for this fork PR. A maintainer would need to approve the workflow run before merge.
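Although hardware detection is hard to test end to end, the parsing step can be covered by replacing `subprocess.run` with canned output. The following is a minimal sketch under stated assumptions: `detect_arch` is a hypothetical stand-in for the detection logic, not the bitsandbytes function, and the canned `hipinfo` line is written from the field name quoted in this review.

```python
import re
import subprocess
from types import SimpleNamespace
from unittest import mock


def detect_arch(tool: str, pattern: str) -> str:
    """Hypothetical stand-in: run a query tool and extract the gfx arch."""
    try:
        result = subprocess.run([tool], capture_output=True, text=True)
    except OSError:
        # Tool not installed or not executable.
        return "unknown"
    match = re.search(pattern, result.stdout, re.IGNORECASE | re.MULTILINE)
    return match.group(1) if match else "unknown"


def test_detect_arch_with_canned_output():
    # Assumed output line; only the stdout attribute is needed by detect_arch.
    fake = SimpleNamespace(stdout="gcnArchName:    gfx1100\n", returncode=0)
    with mock.patch("subprocess.run", return_value=fake) as run:
        arch = detect_arch("hipinfo", r"^gcnArchName:\s+(gfx[a-z\d]+)")
    run.assert_called_once()
    assert arch == "gfx1100"


def test_detect_arch_when_tool_missing():
    with mock.patch("subprocess.run", side_effect=FileNotFoundError):
        assert detect_arch("hipinfo", r"^gcnArchName:\s+(gfx[a-z\d]+)") == "unknown"
```

Both tests run under pytest on any machine, with or without ROCm installed, because no real process is started.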
- Security: Clear (no new dangerous patterns; subprocess.run with hipinfo/rocminfo is the existing pattern)
- Downstream impact: None (internal GPU detection, no public API changes)
- Tests: Missing but non-blocking for hardware detection changes
- CI: Not triggered (fork PR)
- Cross-PR conflicts: Conflicts with #1846 and #1843. All three modify bitsandbytes/cuda_specs.py for the same purpose. #1846 is the most complete and correct. Recommend closing #1833 and #1843 in favor of #1846.
Verdict: Request changes (close in favor of #1846). The get_rocm_gpu_arch() regex does not match hipinfo.exe output format, so this PR does not achieve its stated goal on Windows. PR #1846 from the same author fixes this correctly and includes the necessary build system changes.
    if (sys.platform == "win32"):
        rocminfo = "hipinfo"
    else:
        rocminfo = "rocminfo"
import sys should be placed before import torch (stdlib before third-party per PEP 8). Also, this module-level platform check runs for all users on import. Consider keeping the platform-conditional logic inside the functions that need it, as PR #1846 does with platform.system() == "Windows" checks.
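A minimal sketch of the suggested layout follows. The helper name `rocm_info_tool` and its optional `system` parameter are hypothetical, introduced only so the choice can be exercised on any machine; the actual structure in PR #1846 may differ.

```python
# Standard-library imports come first, per PEP 8.
import platform
from typing import Optional

# Third-party imports such as `import torch` would follow here.


def rocm_info_tool(system: Optional[str] = None) -> str:
    """Pick the ROCm query tool at call time instead of at import time.

    `system` defaults to platform.system(); passing it explicitly is a
    convenience for tests and is an assumption of this sketch.
    """
    if system is None:
        system = platform.system()
    return "hipinfo" if system == "Windows" else "rocminfo"


print(rocm_info_tool("Windows"))
print(rocm_info_tool("Linux"))
```

With the check inside a function, importing the module has no platform-dependent side effect, and callers that never query ROCm never evaluate it.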
    - result = subprocess.run(["rocminfo"], capture_output=True, text=True)
    - match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
    + result = subprocess.run([rocminfo], capture_output=True, text=True)
    + match = re.search(r"Name:\s+gfx([a-z\d]+)", result.stdout, re.IGNORECASE)
This regex Name:\s+gfx(...) will not match hipinfo.exe output on Windows. The hipinfo.exe utility reports GPU architecture under gcnArchName:, not Name:. For example, hipinfo outputs gcnArchName: gfx1100, not Name: gfx1100. This means get_rocm_gpu_arch() will always return "unknown" on Windows, defeating the purpose of this PR. PR #1846 handles this correctly with a separate gcnArchName: pattern for Windows.
    - result = subprocess.run(["rocminfo"], capture_output=True, text=True)
    - match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
    + result = subprocess.run([rocminfo], capture_output=True, text=True)
    + match = re.search(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE)
The combined regex (wavefront\s|warp)size: is creative but fragile. The [0-9]{2} in group 2 requires exactly two digits, which works for 32/64 but would break for unexpected values. More importantly, the third optional group (\([x0-9]{4}\))? requires exactly 4 characters inside parentheses — this happens to match (0x40) for warp size 64 but would fail for other hex representations. PR #1846's approach of using cleanly separated patterns (warpSize:\s+(\d+) for Windows, the existing Wavefront Size: pattern for Linux) is more robust and readable.
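The difference between the combined pattern and the separated patterns can be checked on sample strings. This is a sketch under stated assumptions: the sample lines are hand-written from the field names quoted above, the value 128 is a made-up input used only to exercise the digit-count limit, and `parse_warp_size` is a hypothetical helper rather than code from either PR.

```python
import re
from typing import Optional

# Combined pattern from this PR.
COMBINED = re.compile(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", re.IGNORECASE)

# Separated patterns as described for PR #1846.
WINDOWS = re.compile(r"warpSize:\s+(\d+)")
LINUX = re.compile(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)")


def parse_warp_size(output: str, system: str) -> Optional[int]:
    """Return the warp/wavefront size, or None if the field is absent."""
    pattern = WINDOWS if system == "Windows" else LINUX
    match = pattern.search(output)
    return int(match.group(1)) if match else None


# With a three-digit input the combined pattern still matches, but it
# captures only the first two digits instead of reporting a failure.
truncated = COMBINED.search("warpSize:    128").group(2)
print(truncated)  # prints "12"

print(parse_warp_size("warpSize:    128", "Windows"))
print(parse_warp_size("  Wavefront Size:   64(0x40)", "Linux"))
```

The silent truncation is the practical risk here: a wrong value is harder to notice than a failed match, which is one reason an unbounded `\d+` capture is easier to reason about.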